from utils.util_calculate_psnr_ssim import *
import numpy as np
import torch
import torch.nn.functional as F

def test_backbone(model,device_name, testset,iter_time,test_loader):
    model.training = False
    model.eval()
    
    device = torch.device(device_name)
    
    avg_psnr, avg_ssim, avg_psnrb = 0,0,0
    sample_number = 0
    for i, test_data in enumerate(test_loader):
        img_lq, x_gt,img_name = test_data["L"].to(device), test_data["H"].to(device),test_data["H_path"][0]
        #x_gt = x_gt.flatten(0, 1)
        sample_number += 1
        
        # Pad the input if not_multiple_of 8
        img_multiple_of=8
        height,width = img_lq.shape[2], img_lq.shape[3]
        H,W = ((height+img_multiple_of)//img_multiple_of)*img_multiple_of, ((width+img_multiple_of)//img_multiple_of)*img_multiple_of
        padh = H-height if height%img_multiple_of!=0 else 0
        padw = W-width if width%img_multiple_of!=0 else 0
        img_lq = F.pad(img_lq, (0,padw,0,padh), 'reflect')

        with torch.no_grad():
            output = model(img_lq)

        output = output[:, :, :height, :width]

        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = (output * 255.0).round().astype(np.uint8)
        x_gt = x_gt.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        x_gt = (x_gt * 255.0).round().astype(np.uint8)
        #print(output.shape,x_gt.shape)

        psnr = calculate_psnr(output, x_gt, crop_border=0)
        #ssim = calculate_ssim(output, x_gt, crop_border=0)
        ssim = calculate_ssim(output, x_gt, crop_border=0, input_order = 'CHW')
        
        psnrb = 0

        criterion_info ="img name {},testset {}, iter_time {}, psnr/ssim/psnrb is {}/{}/{}".format(img_name,testset,iter_time,psnr,ssim,psnrb)
        print(criterion_info)


        #print(criterion_info)
        avg_psnr += psnr
        avg_ssim += ssim
        #avg_lpips+=lpips_
        avg_psnrb += psnrb
    avg_psnr = round(avg_psnr / sample_number,2)
    avg_ssim = round(avg_ssim / sample_number,4)
    avg_psnrb = avg_psnrb / sample_number

    epoch_criterion_info="testset {}, iter_time {}, avg psnr/ssim/psnrb is {}/{}/{}".format(testset,iter_time,avg_psnr,avg_ssim,avg_psnrb)
    print(epoch_criterion_info)